import os
import pandas as pd
import time
import argparse
from utils import load_config, call_api, call_api_stream

def run_benchmark(config):
    """Run every selected benchmark case against the configured model.

    The benchmark CSV is read from ``config["bench_file"]``.  Each row is sent
    to ``call_api_stream`` (when ``config["is_reasoning_model"]`` is truthy) or
    ``call_api`` otherwise, retrying up to ``config["max_retries"]`` times.
    After every successful case the accumulated results are written to
    ``<output_dir>/<bench name>_<model>.csv``; a per-case report and the total
    cost are printed at the end.

    Args:
        config: Mapping with the keys ``output_dir``, ``bench_file``,
            ``cases`` (1-based case numbers; a falsy value selects every
            row), ``model``, ``max_retries`` and ``is_reasoning_model``.
            It is also passed through unchanged to the API helpers.

    Returns:
        None.  Results are written to disk and printed.
    """
    # Create output directory
    os.makedirs(config["output_dir"], exist_ok=True)

    # Load benchmark data
    bench_file = config["bench_file"]
    df_bench = pd.read_csv(bench_file)

    # Filter cases based on configuration (case 1 = positional index 0)
    cases = config["cases"]
    if cases:
        df_bench = df_bench.iloc[[case_num - 1 for case_num in cases]]

    # Result file path ("/" in the model name would otherwise create subdirs)
    model_name = config["model"].replace("/", "-")
    res_file = os.path.join(
        config["output_dir"],
        f"{os.path.basename(bench_file).replace('.csv', '')}_{model_name}.csv",
    )

    results = []
    total_cost = 0
    max_retries = config["max_retries"]

    # Process each case sequentially
    for index, row in df_bench.iterrows():
        # ``iloc`` keeps the original row labels, so label + 1 is the case number
        case_num = index + 1
        prompt = row['prompt']
        label = row['character2']
        option1 = row['option1']
        option2 = row['option2']
        option3 = row['option3']
        option4 = row['option4']

        print(f"Processing case {case_num} (index {index})")

        # Try calling API multiple times
        for attempt in range(max_retries):
            # Only the API call and the parsing of its answer are retried.
            # Bookkeeping and saving happen after the ``try`` so that a failed
            # write cannot trigger another paid API call or a duplicate row.
            try:
                if config["is_reasoning_model"]:
                    api = call_api_stream
                else:
                    api = call_api
                res, response, res_json, tokens, cost = api(
                    prompt, config, option1, option2, option3, option4, label
                )

                # Count the cost of every completed call, even if parsing the
                # per-option values below fails and the call is retried.
                total_cost += cost
                success = int(res == label)

                result = {
                    'case_num': case_num,
                    'prompt': prompt,
                    'gt': label,
                    'option1': option1,
                    'option2': option2,
                    'option3': option3,
                    'option4': option4,
                    'res': res,
                    'response': response,
                    'success': success,
                    'cost': cost,
                    'tokens': tokens,
                    'prob1': res_json[option1],
                    'prob2': res_json[option2],
                    'prob3': res_json[option3],
                    'prob4': res_json[option4],
                }
            except Exception as e:
                print(f"  Attempt {attempt+1}/{max_retries} failed: {e}")
                # No point in waiting when there is no further attempt.
                if attempt + 1 < max_retries:
                    time.sleep(5)
                continue

            results.append(result)
            print(f"  Result: {'Correct' if success else 'Incorrect'}, Cost: {cost:.2f}")

            # Save after every case so partial progress survives a crash
            try:
                pd.DataFrame(results).to_csv(res_file, index=False)
            except OSError as e:
                print(f"  Warning: could not save results to {res_file}: {e}")
            break
        else:
            print("  All attempts failed, skipping this case")

    # Build the report frame from the collected results.  This also covers the
    # situation where no case succeeded (or none was selected), in which case
    # the frame is empty and only the header and total are printed.
    res_df = pd.DataFrame(results)

    # Print detailed results for each case
    print("\n" + "="*80)
    print(f"DETAILED RESULTS FOR MODEL: {config['model']}")
    print("="*80)

    for _, result in res_df.iterrows():
        print(f"\nCase {result['case_num']}:")
        print(f"  Option 1 ({result['option1']}): {result['prob1']}")
        print(f"  Option 2 ({result['option2']}): {result['prob2']}")
        print(f"  Option 3 ({result['option3']}): {result['prob3']}")
        print(f"  Option 4 ({result['option4']}): {result['prob4']}")
        print(f"  Correct answer: {result['gt']}")
        print(f"  Model prediction: {result['res']}")
        print(f"  Result: {'✓ CORRECT' if result['success'] else '✗ INCORRECT'}")

    print("\n" + "="*80)
    print(f"Total cost: {total_cost}")
    print("="*80)

def main():
    """Parse the command line, load the configuration and run the benchmark.

    The only argument is an optional positional path to the config file,
    defaulting to ``config.json``.
    """
    # Accept config path as a positional argument
    parser = argparse.ArgumentParser(description='Run PersonaEval benchmark')
    parser.add_argument('config_path', nargs='?', default='config.json', help='Path to config file')
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config_path)

    print(f"Using model: {config['model']}")

    # run_benchmark treats a falsy ``cases`` value as "run every case", so the
    # banner must not try to iterate it (None would raise TypeError here).
    cases = config['cases']
    if cases:
        print(f"Running cases: {cases} (corresponding to indices {[case-1 for case in cases]})")
    else:
        print("Running cases: all")

    # Run benchmark
    run_benchmark(config)

# Run the benchmark only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()